/*
 * Copyright 2018-2019 Baidu, Inc. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */
package com.baidubce.services.iothisk.device;

import static com.baidubce.services.iothisk.device.utils.CounterUtils.validCounter;
import static com.baidubce.services.iothisk.device.utils.Message.INVALID_DEVICE_COMPANY;
import static com.baidubce.services.iothisk.device.utils.Message.INVALID_DEVICE_TYPE;
import static com.baidubce.services.iothisk.device.utils.Message.INVALID_SERIAL_NUMBER;
import static com.baidubce.services.iothisk.device.utils.Message.NULL_CONTRACT;
import static com.baidubce.services.iothisk.device.utils.Message.NULL_DEVICE_SDK_TYPE;
import static com.baidubce.services.iothisk.device.utils.Message.NULL_SERIAL_NUMBER;
import static com.baidubce.services.iothisk.device.utils.Message.UNKNOWN_DEVICE_SDK_TYPE;
import static com.google.common.base.Preconditions.checkNotNull;
import static org.bouncycastle.jce.provider.BouncyCastleProvider.PROVIDER_NAME;

import java.security.Security;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;

import com.baidubce.services.iothisk.device.model.ActiveMessage;
import com.baidubce.services.iothisk.device.model.CipherMessage;
import com.baidubce.services.iothisk.device.model.Device;
import com.baidubce.services.iothisk.device.model.DeviceKey;
import com.baidubce.services.iothisk.device.model.PlainMessage;
import com.baidubce.services.iothisk.device.seplatform.SecureElement;
import com.baidubce.services.iothisk.device.seplatform.SecureElementFactory;

/**
 * Provides a hisk device with the capabilities of its local secure element: device id generation,
 * activation data generation, and counter-protected message encryption/decryption.
 * Currently, only the baidu mbed_akey secure element is supported.
 *
 * <p>Instances perform no internal synchronization: the {@code NONE_RTC} counter is a plain field
 * that {@link #decrypt(byte[])} reads, checks and then overwrites. Callers sharing one instance
 * across threads must synchronize externally.
 */
public class IotHiskDevice {

    private static final int SERIAL_NUMBER_MAX_LENGTH = 32;
    private static final int DEVICE_COMPANY_MAX_LENGTH = 32;
    private static final int DEVICE_TYPE_MAX_LENGTH = 48;
    private static final int DEVICE_COMPANY_AND_TYPE_MIN_LENGTH = 3;

    /**
     * Allowed characters for device company and device type names: CJK unified ideographs in the
     * range U+4E00..U+9FA5, word characters ({@code [a-zA-Z0-9_]}) and hyphens.
     */
    private static final Pattern NAME_PATTERN = Pattern.compile("[\u4e00-\u9fa5\\w-]+");

    /**
     * Provided device basic info.
     */
    private final Device device;

    /**
     * Device key info, generated from the provided device.
     */
    private final DeviceKey deviceKey;

    /**
     * Locally tracked counter, used only by {@code NONE_RTC} devices. Initially 0.
     */
    private long currentCounter = 0L;

    /**
     * Registers the bouncy castle security provider once, if it is not already installed.
     */
    static {
        if (Security.getProvider(PROVIDER_NAME) == null) {
            Security.addProvider(new BouncyCastleProvider());
        }
    }

    /**
     * Constructs a new hisk device client using the provided hisk device info.
     *
     * @param device specified hisk basic device info, corresponding to the contract created in hisk cloud service.
     * @throws NullPointerException if {@code device}, its serial number or its device sdk type is null
     * @throws IllegalArgumentException if the device company, device type or serial number is invalid
     */
    public IotHiskDevice(Device device) {
        checkNotNull(device, NULL_CONTRACT);
        validateDevice(device);

        this.device = device;
        this.deviceKey = generateDeviceKey(device);
    }

    /**
     * Get the unique device id in ascii encoding.
     *
     * @return unique device id, as generated by the secure element at construction time
     */
    public String getDeviceId() {
        return deviceKey.getDeviceId();
    }

    /**
     * Get device active data generated by the device secure element.
     * The active message (device id, sdk type, secure element id) is encrypted with {@link #encrypt(byte[])}.
     *
     * @return successful device active data in byte array, otherwise an exception will be thrown
     */
    public byte[] getActiveData() {
        ActiveMessage activeMessage = new ActiveMessage();
        activeMessage.setDeviceId(deviceKey.getDeviceId());
        activeMessage.setSdkType(device.getDeviceSdkType());
        activeMessage.setSeId(deviceKey.getSeId());

        return encrypt(activeMessage.getBytes());
    }

    /**
     * Encrypt a message with the device secure element. The message is paired with the current
     * counter (see {@link #getCurrentCounter()}) before being encrypted and signed.
     *
     * @param message plain message in byte array
     * @return successful cipher message in byte array, otherwise an exception will be thrown
     */
    public byte[] encrypt(byte[] message) {
        PlainMessage plainMessage = new PlainMessage(getCurrentCounter(), message);
        return deviceKey.getSe().encryptThenSign(plainMessage).getBytes();
    }

    /**
     * Decrypt a cipher with the device secure element. After verification and decryption, the
     * message counter is checked against the current counter and then stored as the new counter.
     *
     * @param cipherMessage cipher message in byte array
     * @return successful plain message in byte array, otherwise an exception will be thrown
     */
    public byte[] decrypt(byte[] cipherMessage) {
        CipherMessage cipher = deviceKey.getSe().parseCipherMessage(cipherMessage);
        PlainMessage plainMessage = deviceKey.getSe().verifyThenDecrypt(cipher);
        // Counter check is the replay-attack defence; only after it passes is the counter advanced.
        validCounter(device.getDeviceSdkType(), plainMessage.getCounter(), getCurrentCounter());
        setCurrentCounter(plainMessage.getCounter());

        return plainMessage.getMessage();
    }

    /**
     * Get the device client's current counter. The counter is used to check messages to resist replay attacks.
     *
     * @return the locally stored counter for {@code NONE_RTC} devices; the current epoch time in
     *         seconds for {@code RTC} and {@code NONE_COUNTER} devices
     * @throws IllegalArgumentException if the device sdk type is not recognised
     */
    public long getCurrentCounter() {
        switch (device.getDeviceSdkType()) {
            case NONE_RTC:
                return currentCounter;
            case RTC:
            case NONE_COUNTER:
                // NONE_COUNTER performs no counter check, so any value is acceptable;
                // it shares the RTC time-based value.
                return System.currentTimeMillis() / 1000;
            default:
                throw new IllegalArgumentException(UNKNOWN_DEVICE_SDK_TYPE);
        }
    }

    /**
     * Set the device counter to a user-specified value. Only {@code NONE_RTC} devices store the
     * counter; for {@code RTC} and {@code NONE_COUNTER} devices this is a no-op.
     *
     * @param counter specified counter
     * @throws IllegalArgumentException if the device sdk type is not recognised
     */
    public void setCurrentCounter(long counter) {
        switch (device.getDeviceSdkType()) {
            case NONE_RTC:
                this.currentCounter = counter;
                break;
            case NONE_COUNTER:
            case RTC:
                // nothing to do: these types derive the counter from the system clock
                break;
            default:
                throw new IllegalArgumentException(UNKNOWN_DEVICE_SDK_TYPE);
        }
    }

    /**
     * Generate the local hisk device key, which is used during message encryption and decryption.
     * The device key also holds the secure element created for this device.
     *
     * @param device specified hisk basic device info, corresponding to the contract created in hisk cloud service.
     * @return local hisk device key
     */
    private static DeviceKey generateDeviceKey(Device device) {
        DeviceKey key = new DeviceKey();
        key.setSeId(device.getSerialNumber());
        key.setSeType(device.getType());

        SecureElement se = SecureElementFactory.createSe(device, key);
        key.setDeviceId(se.generateId());
        key.setSe(se);
        return key;
    }

    /**
     * Validate the mandatory device fields.
     *
     * @param device device to validate, assumed non-null
     * @throws NullPointerException if the serial number or device sdk type is null
     * @throws IllegalArgumentException if the company, type or serial number fails validation
     */
    private static void validateDevice(Device device) {
        checkNotNull(device.getSerialNumber(), NULL_SERIAL_NUMBER);
        checkNotNull(device.getDeviceSdkType(), NULL_DEVICE_SDK_TYPE);

        validateName(device.getDeviceCompany(), DEVICE_COMPANY_MAX_LENGTH, INVALID_DEVICE_COMPANY);
        validateName(device.getDeviceType(), DEVICE_TYPE_MAX_LENGTH, INVALID_DEVICE_TYPE);
        validateSerialNumber(device.getSerialNumber());
    }

    /**
     * Validate a company or type name: non-blank, between
     * {@code DEVICE_COMPANY_AND_TYPE_MIN_LENGTH} and {@code maxLength} characters, and made only of
     * characters accepted by {@code NAME_PATTERN}.
     *
     * @param name name to validate; null is treated as blank
     * @param maxLength maximum allowed length, inclusive
     * @param errorMessage message of the exception thrown on failure
     * @throws IllegalArgumentException if the name is invalid
     */
    private static void validateName(String name, int maxLength, String errorMessage) {
        // isBlank also covers null, so the matcher is never called with a null argument.
        if (StringUtils.isBlank(name)
                || StringUtils.length(name) < DEVICE_COMPANY_AND_TYPE_MIN_LENGTH
                || StringUtils.length(name) > maxLength
                || !NAME_PATTERN.matcher(name).matches()) {
            throw new IllegalArgumentException(errorMessage);
        }
    }

    /**
     * Validate a serial number: non-blank and at most {@code SERIAL_NUMBER_MAX_LENGTH} characters.
     *
     * @param serialNumber serial number to validate
     * @throws IllegalArgumentException if the serial number is invalid
     */
    private static void validateSerialNumber(String serialNumber) {
        if (StringUtils.isBlank(serialNumber) || StringUtils.length(serialNumber) > SERIAL_NUMBER_MAX_LENGTH) {
            throw new IllegalArgumentException(INVALID_SERIAL_NUMBER);
        }
    }

}
